# Modification 2020 RangiLyu
# Copyright 2018-2019 Open-MMLab.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...loss.iou_loss import bbox_overlaps
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


class ATSSAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal is assigned `0` or a positive integer indicating the
    ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        topk (int): number of bboxes selected per gt on each pyramid level.
    """

    def __init__(self, topk):
        self.topk = topk

    # Reference implementation:
    # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py

    def assign(
        self, bboxes, num_level_bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None
    ):
        """Assign gt to bboxes.

        The assignment is done in the following steps:

        1. compute iou between all bboxes (of all pyramid levels) and gts
        2. compute center distance between all bboxes and gts
        3. on each pyramid level, for each gt, select the k bboxes whose
           centers are closest to the gt center, giving (at most) k * l
           candidates per gt for l levels
        4. take the iou of these candidates, compute its mean and std per gt,
           and use mean + std as the iou threshold
        5. keep the candidates whose iou is greater than or equal to the
           threshold as positive
        6. additionally require the positive sample's center to lie inside
           the gt

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape (n, 4).
                Only the first four columns are used.
            num_level_bboxes (List): Number of bboxes in each level; the
                levels are taken as consecutive slices of ``bboxes``.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO. Accepted
                for interface compatibility; not used by this method.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result. ``max_overlaps`` holds
            the iou with the assigned gt for positive bboxes and ``-INF``
            (-100000000) for bboxes without an assigned gt; when there are
            no gts or no bboxes it is all zeros.
        """
        INF = 100000000
        # NOTE: per the original author's note these boxes are the anchors
        # (x1, y1, x2, y2), not predicted boxes. Drop any extra columns.
        bboxes = bboxes[:, :4]
        num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        # IoU between every bbox and every gt, shape (num_bboxes, num_gt).
        overlaps = bbox_overlaps(bboxes, gt_bboxes)

        # Every bbox starts as background (0).
        assigned_gt_inds = overlaps.new_full((num_bboxes,), 0, dtype=torch.long)

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or no boxes: everything stays background.
            max_overlaps = overlaps.new_zeros((num_bboxes,))
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes,), -1, dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
            )

        # Euclidean distance between every bbox center and every gt center.
        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        gt_points = torch.stack((gt_cx, gt_cy), dim=1)  # (num_gt, 2)

        bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0  # (num_bboxes, )
        bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
        bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)  # (num_bboxes, 2)

        # (num_bboxes, 1, 2) - (1, num_gt, 2) -> (num_bboxes, num_gt)
        distances = (
            (bboxes_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()
        )

        # On each pyramid level, for each gt, pick the k bboxes whose centers
        # are closest to the gt center.
        candidate_idxs = []
        start_idx = 0
        for bboxes_per_level in num_level_bboxes:
            end_idx = start_idx + bboxes_per_level
            # A level may hold fewer than topk bboxes (e.g. small inputs).
            selectable_k = min(self.topk, bboxes_per_level)
            _, topk_idxs_per_level = distances[start_idx:end_idx, :].topk(
                selectable_k, dim=0, largest=False
            )
            # topk indices are relative to the level slice; make them absolute.
            candidate_idxs.append(topk_idxs_per_level + start_idx)
            start_idx = end_idx
        # (num_candidates, num_gt): column g lists the candidate bboxes of gt g.
        candidate_idxs = torch.cat(candidate_idxs, dim=0)

        # Column index of each gt, created on the same device as the candidates
        # so the advanced indexing below does not mix devices.
        gt_cols = torch.arange(num_gt, device=candidate_idxs.device)

        # Per-gt adaptive threshold: mean + std of the candidates' iou.
        candidate_overlaps = overlaps[candidate_idxs, gt_cols]  # (num_candidates, num_gt)
        overlaps_mean_per_gt = candidate_overlaps.mean(0)
        if candidate_overlaps.size(0) > 1:
            overlaps_std_per_gt = candidate_overlaps.std(0)
        else:
            # The sample std of a single value is NaN, which would turn the
            # threshold into NaN and silently reject the only candidate.
            overlaps_std_per_gt = torch.zeros_like(overlaps_mean_per_gt)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]

        # Limit the positive samples' centers to lie inside the gt: distance
        # from the candidate center to each gt side must exceed 0.01.
        candidate_cx = bboxes_cx[candidate_idxs]  # (num_candidates, num_gt)
        candidate_cy = bboxes_cy[candidate_idxs]
        l_ = candidate_cx - gt_bboxes[:, 0]
        t_ = candidate_cy - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - candidate_cx
        b_ = gt_bboxes[:, 3] - candidate_cy
        is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
        is_pos = is_pos & is_in_gts

        # If a bbox is a positive candidate for multiple gts, the gt with the
        # highest IoU wins. Non-positive (bbox, gt) pairs are set to -INF so
        # they can never be selected by the max below.
        positive_pairs = torch.zeros_like(overlaps, dtype=torch.bool)
        positive_pairs[
            candidate_idxs[is_pos], gt_cols.expand_as(candidate_idxs)[is_pos]
        ] = True
        overlaps_inf = torch.full_like(overlaps, -INF)
        overlaps_inf[positive_pairs] = overlaps[positive_pairs]

        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
        # Rows still at -INF have no positive pair and remain background;
        # the rest get the 1-based index of their best gt.
        has_gt = max_overlaps != -INF
        assigned_gt_inds[has_gt] = argmax_overlaps[has_gt] + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
            # squeeze(1) keeps a 1-D index tensor even with a single positive.
            pos_inds = torch.nonzero(assigned_gt_inds > 0, as_tuple=False).squeeze(1)
            if pos_inds.numel() > 0:
                # gt indices are 1-based, gt_labels is 0-based.
                assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
        )
